Add JAX integration tests #1685
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report
@@ Coverage Diff @@
## master #1685 +/- ##
=======================================
Coverage 98.90% 98.90%
=======================================
Files 206 206
Lines 15388 15396 +8
=======================================
+ Hits 15219 15227 +8
Misses 169 169
Continue to review full report at Codecov.
@@ -279,6 +279,6 @@ def requires_grad(tensor, interface=None):
     if interface == "jax":
         import jax

-        return isinstance(tensor, jax.interpreters.ad.JVPTracer)
+        return isinstance(tensor, jax.core.Tracer)
jax.core.Tracer is the original parent class, so this is a lot safer :) There are cases I discovered where JAX will use tracers that aren't JVPTracer.
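To illustrate the point above, here is a minimal sketch (assuming a JAX version where jax.core.Tracer is accessible, as it was when this PR was written): different JAX transforms trace with different tracer subclasses, but all of them inherit from jax.core.Tracer.

```python
import jax

def square(x):
    # Every tracer type used by JAX's transforms subclasses the common
    # parent class jax.core.Tracer, so this check holds under any of them.
    assert isinstance(x, jax.core.Tracer)
    return x ** 2

# jax.grad traces with a JVPTracer, while jax.jit traces with a
# different tracer type, so checking for JVPTracer alone would miss it.
grad_val = jax.grad(square)(3.0)
jit_val = jax.jit(square)(3.0)
```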
@@ -400,7 +400,7 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su
         coeffs = coeffs.numpy()

         expected = np.dot(qcval, coeffs)
-        assert np.all(res == expected)
+        assert np.allclose(res, expected)
For some reason, this test was failing for me on CI (but not locally) 🤔
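Exact floating-point equality is fragile in general, since rounding can differ between BLAS builds (one reason a test may pass locally but fail on CI); np.allclose compares within a tolerance instead. A minimal NumPy illustration:

```python
import numpy as np

# 0.1 + 0.2 is not exactly 0.3 in binary floating point.
a = np.array([0.1 + 0.2])
b = np.array([0.3])

exact = np.all(a == b)     # False: bitwise comparison
close = np.allclose(a, b)  # True: comparison within tolerance
```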
Overall looks good to me! 🙂 Left some questions, but no major blocker.
If I recall correctly, we plan to discontinue support for QNodes with ragged outputs (right? 🤔). Would that affect the failing case of Ragged QNodes in backprop mode?
    if interface == "jax":
        import jax

        if not any(isinstance(v, jax.core.Tracer) for v in values):
This looks like a great change to have 🥇 How come it's placed here, instead of in the JAX branch of requires_grad?
I think it needs to be here, since the not any check can only be done here; it cannot be done inside the requires_grad check (which only checks a single tensor at a time) 🤔 I could be wrong though, let me know if you see a way around this!
No, I think you're right. 🤔 At least nothing else comes to mind that we could use here.
        qml.CNOT(wires=[0, 1])
        return qml.expval(qml.PauliY(1))

    res = jax.grad(cost_fn, argnums=[0, 1])(a, b, shots=30000)
How does the test case make sure that we had shots=30000 instead of shots=100? Would the deviation from the expected analytic result be even bigger than 0.1 with shots=100?
Yes, it is not the most ideal way of testing this, but I spent a while and couldn't come up with anything better 🤔 The other interfaces use shots=[(1, 1000)], which is nicer since the output shape changes. However, it can't be used for JAX, since JAX only supports scalar outputs :(
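For intuition on the tolerance question: the standard error of a shot-based expectation estimate scales as 1/sqrt(shots), so 100 shots gives noise of order 0.1, while 30000 shots gives roughly 0.006. A rough simulation with a toy ±1 measurement (plain NumPy, not an actual PennyLane circuit):

```python
import numpy as np

rng = np.random.default_rng(42)

def estimate(shots):
    # toy model of an expectation value built from +/-1 samples
    samples = rng.choice([1.0, -1.0], size=shots)
    return samples.mean()

# Shot noise scales as 1/sqrt(shots): roughly 0.1 at 100 shots versus
# roughly 0.006 at 30000 shots, so a 0.1 tolerance is only barely
# compatible with shots=100.
err_100 = np.std([estimate(100) for _ in range(300)])
err_30000 = np.std([estimate(30000) for _ in range(300)])
```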
    assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift

    # if we set the shots to None, backprop can now be used
    cost_fn(a, b, shots=None)
Should we then pass in diff_method="param-shift" if we really wanted to use parameter shift with shots=None?
Yep! The reason the internal diff_method is changing here is because, by default, diff_method="best". If you instead set diff_method="param-shift", then it will not change dynamically.
    if diff_method not in {"backprop"}:
        pytest.skip("Test only supports backprop")
Would this have any effect at this point?
    if diff_method not in {"backprop"}:
    pytest.skip("Test only supports backprop")
I think I would prefer to leave it in, since we may have support for vector valued QNodes in parameter-shift mode at some point in the future!
Right! Would we only test assert res.dtype is np.dtype("complex128") for non-backprop diff methods? That seems to be the only statement before skipping the test.
Co-authored-by: antalszava <antalszava@gmail.com>
Looks good to me! 💯 Thank you for adding these tests to check the JAX interface. 😍
Just double-checking this one:
If I recall correctly, we plan to discontinue support for QNodes with ragged outputs (right? 🤔). Would that affect the failing case of Ragged QNodes in backprop mode?
This is just to understand better what the priority of the drawbacks would be that were identified in the PR description. 🙂
Yes! Hopefully, this would allow these QNodes to work with JAX, once we make that change 🙂
Context: The JAX interface is missing many of the integration tests that the other interfaces have. This PR adds these integration tests in, and makes note of where the JAX interface may be lacking in feature parity.
Description of the Change:
- Adds a JAX integration test to test_gradient_transform.py (including JIT tests).
- Adds QNode integration tests in tests/interfaces/test_batch_jax_qnode.py.
- Adds a batch transform integration test to test_batch_transform.py.
- Modifies qml.math.get_trainable_indices() to correctly return results when JAX is performing a forward-only computation. This is done by making the following change:
  - DeviceArray objects are treated as trainable.
  - jax.core.Tracer objects are treated as trainable (which matches the jax.grad(cost, argnum=...) argument).

This change allows the metric tensor/gradient transform functions (which apply in forward-only mode but require knowledge of trainable parameters) to apply to QNodes when using JAX.

This is required because, since JAX does not have a method of specifying trainable parameters on the forward pass, previously gradient transforms would simply register no trainable parameters on forward passes, and return an empty list as a result! Paradoxically, differentiating the gradient transform would work fine, since the trainable parameters are now registered.
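The trainability rule described above can be sketched roughly as follows. This is a hypothetical standalone version for illustration, not PennyLane's actual implementation, and it assumes jax.core.Tracer is accessible:

```python
import jax
import jax.numpy as jnp

def trainable_indices(values):
    # Tracers are exactly the arguments being differentiated (those
    # selected via jax.grad(cost, argnums=...)); if none are present,
    # this is a forward-only pass and every value, e.g. a concrete
    # DeviceArray, is treated as trainable.
    tracer_idx = {i for i, v in enumerate(values) if isinstance(v, jax.core.Tracer)}
    return tracer_idx if tracer_idx else set(range(len(values)))

# Forward-only pass: no tracers, so both arguments count as trainable.
fwd = trainable_indices([jnp.array(0.1), jnp.array(0.2)])

# Under jax.grad with argnums=0, only x arrives as a tracer.
seen = []
def cost(x, y):
    seen.append(trainable_indices([x, y]))
    return x * y

grad_x = jax.grad(cost, argnums=0)(0.1, 0.2)
```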
Benefits:
Better tests for JAX, and a better idea of what works, and what doesn't.
Gradient transforms, and the metric tensor, now work for the JAX interface.
Possible Drawbacks:
Several areas were noticed where JAX usage resulted in errors or issues, unlike the other interfaces:
- Ragged QNodes in backprop mode, e.g., return qml.expval(qml.PauliZ(0)), qml.probs(wires=[1]). This appears to be because line 230 in QubitDevice (results = self._asarray(results)) fails.
- Hamiltonian expansion of expval(H) when using finite shots fails. When using finite shots, expval(H) results in the tape being expanded to multiple tapes, each one having a vector-valued output. Since the JAX parameter-shift interface does not support vector-valued tapes, an error occurs.
- JIT mode with the adjoint method. Since the adjoint method is not using host_callback in the jax.py interface, it results in a failure attempting to JIT the QNode.
- JIT mode with jax.grad(cost, argnums=...) where argnums is a subset of allowed arguments, e.g., argnums=[0, 2]. In this scenario, the JAX interface only passes unwrapped trainable parameters to the host_callback; as a result, non-trainable parameters remain as jax.core.Tracer objects on the tape, which cannot be understood by the device.

Related GitHub Issues: n/a